"""Protein relax pipeline
1. Usage:
$ python3 protein_relax.py -i examples/protein/case2.pdb -o examples/protein/case2-optimized.pdb
"""
import os

os.environ["GLOG_v"] = "3"
import numpy as np
from mindspore import context, Tensor, nn
from mindspore.nn import CellList
from mindspore import numpy as msnp
import mindspore as ms
import argparse
from mindsponge import Sponge
from mindsponge.callback import RunInfo
from mindsponge import set_global_units
from mindsponge import Protein
from mindsponge import ForceField
from mindsponge.optimizer import SteepestDescent
from mindsponge.potential.bias import OscillatorBias
from mindsponge.system.modeling.pdb_generator import gen_pdb

from mindsponge.common.utils import get_pdb_info
from mindsponge.common import residue_constants
from mindsponge.metrics.structure_violations import get_structural_violations

# Command-line interface. The parsed paths are read as module globals by
# get_violation_loss() and main().
parser = argparse.ArgumentParser(description="Relax a protein structure with MindSponge.")
parser.add_argument("-i", required=True, help="Set the input pdb file path.")
parser.add_argument("-o", required=True, help="Set the output pdb file path.")
parser.add_argument(
    "-addh",
    type=int,  # without this, "-addh 0" would be the truthy string "0"
    default=1,
    help="Set to 1 if need to add H atoms, default to be 1.",
)
parser.add_argument(
    "-device_id",
    type=int,
    default=1,
    help="Set the GPU device id, default to be 1.",
)
args = parser.parse_args()
pdb_name = args.i
save_pdb_name = args.o
# NOTE(review): `addh` is parsed but not read anywhere in the visible code —
# confirm whether it should be forwarded to Protein(...).
addh = args.addh
context.set_context(
    mode=context.GRAPH_MODE, device_target="GPU", device_id=args.device_id
)

# Thresholds for structural-violation scoring. VIOLATION_TOLERANCE_ACTOR
# (presumably "factor") is the bond-length tolerance factor and
# CLASH_OVERLAP_TOLERANCE the clash overlap tolerance. Both are reused when
# building the distance bounds below so the two never drift apart.
VIOLATION_TOLERANCE_ACTOR = 12.0
CLASH_OVERLAP_TOLERANCE = 1.5
# One-hot markers over the 14 atom14 slots: slot 2 (C) and slot 0 (N).
C_ONE_HOT = nn.OneHot(depth=14)(Tensor(2, ms.int32))
N_ONE_HOT = nn.OneHot(depth=14)(Tensor(0, ms.int32))
# 14 x 14 identity matrix; presumably masks each atom's distance to itself.
DISTS_MASK_I = msnp.eye(14, 14)
# atom14 slot index of the cysteine SG atom.
CYS_SG_IDX = Tensor(5, ms.int32)
# Per-atom-type radii, 37 entries. Each value follows the element of the type
# (N 1.55, C 1.7, O 1.52, S 1.8). The ordering is presumably
# residue_constants.atom_types (atom37) — TODO confirm.
ATOMTYPE_RADIUS = Tensor(
    np.array(
        [
            1.55, 1.7, 1.7, 1.7, 1.52, 1.7, 1.7, 1.7,       # 0-7
            1.52, 1.52, 1.8, 1.7, 1.7, 1.7, 1.55, 1.55,     # 8-15
            1.52, 1.52, 1.8, 1.7, 1.7, 1.7, 1.7, 1.55,      # 16-23
            1.55, 1.55, 1.52, 1.52, 1.7, 1.55, 1.55, 1.52,  # 24-31
            1.7, 1.7, 1.7, 1.55, 1.52,                      # 32-36
        ]
    ),
    ms.float32,
)
# Lower/upper atom14 distance bounds (and their stddev), derived from the same
# tolerances that get_structural_violations() receives.
LOWER_BOUND, UPPER_BOUND, RESTYPE_ATOM14_BOUND_STDDEV = (
    residue_constants.make_atom14_dists_bounds(
        overlap_tolerance=CLASH_OVERLAP_TOLERANCE,
        bond_length_tolerance_factor=VIOLATION_TOLERANCE_ACTOR,
    )
)
LOWER_BOUND = Tensor(LOWER_BOUND, ms.float32)
UPPER_BOUND = Tensor(UPPER_BOUND, ms.float32)
RESTYPE_ATOM14_BOUND_STDDEV = Tensor(RESTYPE_ATOM14_BOUND_STDDEV, ms.float32)


def get_violation_loss(system):
    """Score the structural violations of ``system``.

    The current coordinates are first written to ``save_pdb_name`` (the output
    path from the command line, overwritten on every call). That file is then
    re-parsed with ``get_pdb_info`` to obtain atom14 features. The features go
    to ``get_structural_violations`` together with the module-level
    tolerance, bound and index constants.

    Args:
        system(Molecule): The molecule whose coordinates are scored.

    Returns:
        Whatever ``get_structural_violations`` returns; ``main`` uses its last
        element as the violation loss.
    """
    coordinates = system.coordinate.asnumpy()
    gen_pdb(
        coordinates,
        system.atom_name[0],
        system.init_resname,
        system.init_resid,
        pdb_name=save_pdb_name,
    )
    pdb_features = get_pdb_info(save_pdb_name)

    def as_tensor(key, dtype):
        # Pull one feature array out of the parsed PDB and cast it.
        return Tensor(pdb_features.get(key)).astype(dtype)

    return get_structural_violations(
        as_tensor("atom14_gt_exists", ms.float32),
        as_tensor("residue_index", ms.float32),
        as_tensor("aatype", ms.int32),
        as_tensor("residx_atom14_to_atom37", ms.int32),
        as_tensor("atom14_gt_positions", ms.float32),
        VIOLATION_TOLERANCE_ACTOR,
        CLASH_OVERLAP_TOLERANCE,
        LOWER_BOUND,
        UPPER_BOUND,
        ATOMTYPE_RADIUS,
        C_ONE_HOT,
        N_ONE_HOT,
        DISTS_MASK_I,
        CYS_SG_IDX,
    )


def _run_adam_rounds(md, system, energy, energy_scale, rounds, steps, learning_rate):
    """Run ``rounds`` Adam passes of ``steps`` steps each under ``energy_scale``.

    The energy scale is applied once, then every pass builds a new ``nn.Adam``
    over ``system.trainable_params()``. Before running, each pass prints the
    parameters it optimizes and the current energy.

    Returns:
        bool: False as soon as the summed energy is NaN after a pass, else True.
    """
    energy.set_energy_scale(energy_scale)
    md.change_potential(energy)
    for _ in range(rounds):
        opt = nn.Adam(system.trainable_params(), learning_rate=learning_rate)
        for i, param in enumerate(opt.trainable_params()):
            print(i, param.name, param.shape)
        md.change_optimizer(opt)
        print(md.energy())
        run_info = RunInfo(10)
        md.run(steps, callbacks=[run_info])
        if msnp.isnan(md.energy().sum()):
            return False
    return True


def optimize_strategy(system, energy, gds, loops, ads, adm, nonh_mask, mode=1):
    """ The optimize strategy including 3 modes.

    A steepest-descent warm-up of ``gds`` steps is followed by ``loops`` rounds.
    Each round sets an ``OscillatorBias`` anchored at a copy of the current
    coordinates, then runs Adam passes under one or two energy scalings
    chosen by ``mode``.

    Args:
        system(Molecule): The given Molecule object; its trainable parameters
            are updated in place.
        energy(ForceField): The potential. Its energy scale is changed by this
            function and is left at the last scaling used.
        gds(int): Optimize steps while using Gradient Descent.
        loops(int): The number of loops to use different optimizers.
        ads(int): The optimize steps of using Adam.
        adm(int): The repeat number of using Adam in each loop.
        nonh_mask(Tensor): Integer atom mask passed to ``SteepestDescent`` and
            ``OscillatorBias``. ``main`` builds it as 1 for hydrogen atoms
            (atomic number 1) and 0 for heavier atoms. How each consumer uses
            it is defined in mindsponge, not here.
        mode(int): The optimize mode; only mode = 1, 2, 3 are supported.
            mode == 1: Use the hybrid optimize strategy which includes total energy and bonded energy.
            mode == 2: Use the total energy only.
            mode == 3: Use the bonded energy only.

    Returns:
        The same ``system`` object on success, or 0 if the energy became NaN.

    Raises:
        ValueError: If ``mode`` is not 1, 2 or 3.
    """
    if mode not in (1, 2, 3):
        raise ValueError(
            "Unsupported optimize mode {}; only 1, 2 and 3 are supported.".format(mode)
        )

    sd_learning_rate = 1e-07
    sd_factor = 1.003
    adam_learning_rate = 5e-02
    k_coe = 10
    # Scalings for energy.set_energy_scale(): all six terms, or only the first
    # three (the "bonded" terms, per the mode description above).
    total_scale = [1, 1, 1, 1, 1, 1]
    bonded_scale = [1, 1, 1, 0, 0, 0]

    opt = SteepestDescent(
        system.trainable_params(),
        learning_rate=sd_learning_rate,
        factor=sd_factor,
        nonh_mask=nonh_mask,
    )
    for i, param in enumerate(opt.trainable_params()):
        print(i, param.name, param.shape)

    md = Sponge(system, energy, opt)
    run_info = RunInfo(10)
    md.run(gds, callbacks=[run_info])

    if msnp.isnan(md.energy().sum()):
        return 0

    for _ in range(loops):
        # `1 *` produces a new tensor, so the bias is anchored at a snapshot of
        # the coordinates rather than at the parameter being optimized.
        harmonic_energy = OscillatorBias(1 * system.coordinate, k_coe, nonh_mask)
        md.sim_system.bias = CellList([harmonic_energy])

        if mode in (1, 2) and not _run_adam_rounds(
            md, system, energy, total_scale, adm, ads, adam_learning_rate
        ):
            return 0

        if mode in (1, 3) and not _run_adam_rounds(
            md, system, energy, bonded_scale, adm, ads, adam_learning_rate
        ):
            return 0

    return system


# Follow-up refinement stages as (label, gds, loops, ads, adm, mode). They run
# in order, each only while the violation loss is still positive (or unknown).
_REFINE_STAGES = (
    ("third", 200, 6, 200, 1, 2),
    ("fourth", 200, 6, 200, 1, 3),
    ("fifth", 100, 8, 100, 1, 3),
    ("final", 30, 2, 150, 2, 1),
)
# Upper bound on step-shrinking retries of the first optimization after a NaN.
_MAX_NAN_RETRIES = 10


def _optimization_failed(result):
    """Return True if ``optimize_strategy`` signalled a NaN energy (it returns 0)."""
    return isinstance(result, int) and result == 0


def _try_violation_loss(system):
    """Return ``get_violation_loss(system)[-1]``, or None on AttributeError.

    The traceback is printed instead of propagated, so a scoring failure does
    not abort the relaxation.
    """
    try:
        violations = get_violation_loss(system)
    except AttributeError:
        import traceback
        traceback.print_exc()
        return None
    return violations[-1]


def _needs_refinement(violation_loss):
    """Return True while the loss is unknown (None) or still positive."""
    return violation_loss is None or violation_loss > 0


def main():
    """Relax the input protein and write the relaxed structure to the output PDB.

    Stages:
      1. Score the input structure.
      2. Run a mode-1 optimization. If it reaches a NaN energy, restore the
         input coordinates and retry with fewer steps, at most
         ``_MAX_NAN_RETRIES`` times.
      3. While violations remain, run the stages in ``_REFINE_STAGES`` in order.
         A stage that reaches a NaN energy stops further refinement.

    Every scoring call writes the current coordinates to the output PDB first
    (see ``get_violation_loss``). On a stopped refinement, the file therefore
    holds the last structure that was scored.

    Raises:
        RuntimeError: If the first optimization still reaches a NaN energy
            after all retries.
    """
    seed = 2333
    ms.set_seed(seed)
    set_global_units("A", "kcal/mol")
    static_system = Protein(pdb=pdb_name)
    energy = ForceField(static_system, "AMBER.FF14SB")
    # 1 for hydrogen atoms, 0 for heavier atoms; presumably (1, n_atoms, 1).
    nonh_mask = Tensor(
        np.where(static_system.atomic_number[0] > 1, 0, 1)[None, :, None], ms.int32
    )

    violation_loss = _try_violation_loss(static_system)
    if violation_loss is not None:
        print("The initial violation loss value is: {}".format(violation_loss))

    # Keep the input coordinates so a diverged attempt can be undone.
    init_coordinate = static_system.coordinate.asnumpy()

    gds, loops, ads, adm = 100, 3, 200, 2
    system = optimize_strategy(static_system, energy, gds, loops, ads, adm, nonh_mask, mode=1)
    retries = 0
    while _optimization_failed(system):
        if retries >= _MAX_NAN_RETRIES:
            raise RuntimeError(
                "Optimization still reached a NaN energy after {} retries.".format(retries)
            )
        retries += 1
        # The diverged run likely left non-finite coordinates in place. Without
        # restoring them, every retry would start from NaN and fail again.
        static_system.coordinate.set_data(Tensor(init_coordinate))
        gds = int(0.5 * gds)
        ads = int(0.8 * ads)
        system = optimize_strategy(static_system, energy, gds, loops, ads, adm, nonh_mask, mode=1)

    violation_loss = _try_violation_loss(system)
    if violation_loss is not None:
        attempt = "first" if retries == 0 else "second"
        print("The {} try violation loss value is: {}".format(attempt, violation_loss))

    for label, gds, loops, ads, adm, mode in _REFINE_STAGES:
        if not _needs_refinement(violation_loss):
            break
        system = optimize_strategy(static_system, energy, gds, loops, ads, adm, nonh_mask, mode=mode)
        if _optimization_failed(system):
            # Scoring a failed result would crash (0 has no coordinates);
            # keep the output PDB from the last successful scoring instead.
            print("The {} try reached a NaN energy; stopping refinement.".format(label))
            break
        violation_loss = _try_violation_loss(system)
        if violation_loss is not None:
            print("The {} try violation loss value is: {}".format(label, violation_loss))


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
